import torch
import numpy as np


def expandTest():
    """Demonstrate adding a leading size-1 dimension in NumPy and in PyTorch.

    Builds a 2x2 NumPy array and wraps it as a tensor with
    ``torch.from_numpy`` (the tensor shares memory with the array). It then
    adds a new leading axis to each:

    * the array, via ``np.expand_dims``;
    * the tensor, via ``None``-indexing.

    Prints the data and shapes before and after. Returns ``None``.
    """
    a = np.array([[1, 2], [3, 4]])
    t = torch.from_numpy(a)
    print(a)
    print(t)
    print(t.shape)
    # NumPy way of adding a dimension: operate on the ndarray itself. Passing
    # the tensor here would rely on an implicit tensor -> ndarray conversion,
    # which fails for CUDA tensors and for tensors that require grad.
    new_arr = np.expand_dims(a, axis=0)
    # Indexing with None inserts a new axis of size 1; t[None] is shorthand
    # for t[None, :] with the trailing ':' omitted.
    t = t[None]

    print(t)
    print(t.shape)
    print(new_arr.shape)
